import os
import sys

# Make modules located two directories up importable. NOTE: this is a relative
# sys.path entry, so Python resolves it against the current working directory
# whenever later imports search sys.path -- not against this file's location.
# (A discarded `os.getcwd()` call that used to follow had no effect and was removed.)
sys.path.append("../../")

import numpy as np
from torch.utils.data import Dataset


class NewMixBackdoorDataset(Dataset):
    """Pair each backdoored sample with the clean sample at the same index.

    Item ``i`` combines ``clean_dataset_with_transform[i]``, which must unpack
    into exactly ``(img, label)``, with ``bd_dataset_with_transform[i]``, which
    must unpack into exactly five values. Only the first two of those five
    (``img, label``) are kept.

    Items are returned as
    ``(bd_img, bd_label, clean_img, clean_label, origin_idx)``.
    """

    def __init__(self, clean_dataset_with_transform, bd_dataset_with_transform):
        self.clean_dataset_with_transform = clean_dataset_with_transform
        self.bd_dataset_with_transform = bd_dataset_with_transform
        # Identity position -> original-index table. It is stored as a NumPy
        # array, so the index handed back by __getitem__ is a NumPy integer.
        size = len(clean_dataset_with_transform)
        self.origin_idx = np.arange(size)

    def __getitem__(self, i):
        """Return ``(bd_img, bd_label, clean_img, clean_label, origin_idx)``."""
        # Fetch the clean sample before the backdoored one. Keep this order:
        # random transforms (if any) may draw from shared RNG state.
        clean_sample = self.clean_dataset_with_transform[i]
        clean_img, clean_label = clean_sample
        poisoned_sample = self.bd_dataset_with_transform[i]
        bd_img, bd_label, _, _, _ = poisoned_sample
        return bd_img, bd_label, clean_img, clean_label, self.origin_idx[i]

    def __len__(self):
        """Number of items, taken from the clean dataset."""
        return len(self.clean_dataset_with_transform)


class SplitDataset(Dataset):
    """Subset view of ``dataset`` restricted to ``indices``, re-indexed from zero.

    ``dataset`` must yield exactly five values per item,
    ``(bd_img, bd_label, clean_img, clean_label, origin_idx)``, as
    ``NewMixBackdoorDataset`` does. The wrapped item's own index is discarded.
    It is replaced by the item's position inside this split: item ``i``
    reports index ``i``.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        # Position-within-split table; a NumPy array, so returned indices are
        # NumPy integers.
        self.origin_idx = np.arange(len(indices))

    def __getitem__(self, i):
        """Return the wrapped item at ``indices[i]`` with its index replaced by ``i``."""
        # int() lets ``indices`` hold NumPy or tensor scalars as well as ints.
        source_pos = int(self.indices[i])
        bd_img, bd_label, clean_img, clean_label, _ = self.dataset[source_pos]
        return bd_img, bd_label, clean_img, clean_label, self.origin_idx[i]

    def __len__(self):
        """Number of selected indices."""
        return len(self.indices)
